In this notebook we'll create a GAN model to generate images of cats.
Based on a Towards Data Science (TDS) post.
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import torch
import torch.nn as nn
import cv2
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline
!wget https://www.dropbox.com/s/329oy3cprlvn5vb/archive.zip
--2022-03-21 15:45:53-- https://www.dropbox.com/s/329oy3cprlvn5vb/archive.zip Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:601b:18::a27d:812 Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/329oy3cprlvn5vb/archive.zip [following] --2022-03-21 15:45:53-- https://www.dropbox.com/s/raw/329oy3cprlvn5vb/archive.zip Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com/cd/0/inline/Bh6945xywQy3yy9zJg9bPbyBl8khkb-eFBjbM0eLF2alypQmVDpJTFDy2rvX3dJlA_RxGL8wt1ERA_upDFuki19ngTXrSKELEI9VxkSltdNSYh572t_asySluxdijSMb_DElmjdpWHft9Z4ljpI6Ai9m3LXLf3HW8hSQH8lm0b7k5g/file# [following] --2022-03-21 15:45:53-- https://uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com/cd/0/inline/Bh6945xywQy3yy9zJg9bPbyBl8khkb-eFBjbM0eLF2alypQmVDpJTFDy2rvX3dJlA_RxGL8wt1ERA_upDFuki19ngTXrSKELEI9VxkSltdNSYh572t_asySluxdijSMb_DElmjdpWHft9Z4ljpI6Ai9m3LXLf3HW8hSQH8lm0b7k5g/file Resolving uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com (uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com)... 162.125.3.15, 2620:100:6018:15::a27d:30f Connecting to uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com (uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com)|162.125.3.15|:443... connected. HTTP request sent, awaiting response... 
302 Found Location: /cd/0/inline2/Bh4qfBCOUmNP1t4jrE2RnFvEOEQT_QKHHPQCyc0T0EDqm3lnxCt2wUt4bpdgEBbZndF9lJhAaic0hJA6mK8VtzIqiYQUCdCcEA0YOhXnbTAApL654MxxXkB3pl1IcgtqejSsAxZ0O15hnHcI7Dr0u4QfiKEFtHX5sB4owCNhGbHVOjL0Kr9t_865vO-gBWA2s6KJLVfcU7sSVUpEer8mglZ2kEnVxWaY_u6PBh1ZUGDIZFsE6MG0YR_qjYVGUg60-s7KZ81__1G17A8dHdMIVsfJuN_s3yX38Y-hXQ8Zaaz8WbdoyTk1X4EYBMBQim26EiUFPwp8hQuKTk-jp2X71AdLyI01yaQTTJXUlanxhVogHZ7fYORH9j4kLPQHxbH25jKJL4wprRn0FpRONNw4hMxdiMTNVz0VDzA_f1syXx4Trw/file [following] --2022-03-21 15:45:53-- https://uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com/cd/0/inline2/Bh4qfBCOUmNP1t4jrE2RnFvEOEQT_QKHHPQCyc0T0EDqm3lnxCt2wUt4bpdgEBbZndF9lJhAaic0hJA6mK8VtzIqiYQUCdCcEA0YOhXnbTAApL654MxxXkB3pl1IcgtqejSsAxZ0O15hnHcI7Dr0u4QfiKEFtHX5sB4owCNhGbHVOjL0Kr9t_865vO-gBWA2s6KJLVfcU7sSVUpEer8mglZ2kEnVxWaY_u6PBh1ZUGDIZFsE6MG0YR_qjYVGUg60-s7KZ81__1G17A8dHdMIVsfJuN_s3yX38Y-hXQ8Zaaz8WbdoyTk1X4EYBMBQim26EiUFPwp8hQuKTk-jp2X71AdLyI01yaQTTJXUlanxhVogHZ7fYORH9j4kLPQHxbH25jKJL4wprRn0FpRONNw4hMxdiMTNVz0VDzA_f1syXx4Trw/file Reusing existing connection to uc5c6cedfa347cc4a8b6343ce218.dl.dropboxusercontent.com:443. HTTP request sent, awaiting response... 200 OK Length: 100661714 (96M) [application/zip] Saving to: ‘archive.zip’ archive.zip 100%[===================>] 96.00M 49.6MB/s in 1.9s 2022-03-21 15:45:56 (49.6 MB/s) - ‘archive.zip’ saved [100661714/100661714]
# Unpack the downloaded dataset archive into the current working directory.
from zipfile import ZipFile

with ZipFile('archive.zip', 'r') as archive:
    # Extract all the contents of the zip file into the current directory
    archive.extractall()
# Root folder of the extracted cat images (ImageFolder expects class
# subdirectories beneath this path).
DATA_DIR = './cats/'
# set parameters of the transformed data
image_size = 64
batch_size = 128
# Per-channel (means, stds) used to scale pixels from [0, 1] to [-1, 1],
# matching the Tanh output range of the generator below.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
# As dataset is stored in the directory, we can create dataset
# as ImageFolder PyTorch object and set all the transformations here
train_ds = ImageFolder(DATA_DIR, transform=tt.Compose([
tt.ToTensor(),
tt.Normalize(*stats)]))
# Create PyTorch DataLoader object to produce batches
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)
def denorm(img_tensors, norm_stats=None):
    """Undo the Normalize transform, mapping pixels from [-1, 1] back to [0, 1].

    Args:
        img_tensors: image tensor(s) that were normalized with ``Normalize(*stats)``.
        norm_stats: optional ``(means, stds)`` pair; defaults to the
            module-level ``stats`` used by the training pipeline, so existing
            callers are unaffected.

    Returns:
        De-normalized tensor computed as ``img * std + mean`` (the first
        channel's values are used, which are identical across channels here).
    """
    if norm_stats is None:
        norm_stats = stats
    means, stds = norm_stats
    return img_tensors * stds[0] + means[0]
# functions to plot images
def show_images(images, nmax=64):
    """Render up to `nmax` images from a batch as an 8-wide grid."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(denorm(images.detach()[:nmax]), nrow=8)
    ax.imshow(grid.permute(1, 2, 0))
def show_batch(dl, nmax=64):
    """Plot the images from the first batch produced by dataloader `dl`."""
    for batch_images, _labels in dl:
        show_images(batch_images, nmax)
        break
show_batch(train_dl)
Discriminator — a model that classifies images as real or generated.
Generator — a decoder that produces images from a vector sampled from a normal distribution.
Note: ConvTranspose layers are used as the inverse of ordinary convolutions.
# Utils functions for GPU usage of neural networks
def get_default_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    has_gpu = torch.cuda.is_available()
    return torch.device('cuda' if has_gpu else 'cpu')
def to_device(data, device):
    """Move a tensor — or every tensor in a list/tuple — onto `device`.

    Lists and tuples are returned as lists of moved elements; single
    tensors are transferred with non_blocking=True.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]
class DeviceDataLoader():
    """Wrap a dataloader so every batch it yields lives on a target device."""

    def __init__(self, dl, device):
        # Underlying dataloader and the device batches are moved to.
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield each batch after transferring it onto the device."""
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)
# Pick the training device once and reuse it everywhere below.
device = get_default_device()
device
device(type='cuda')
train_dl = DeviceDataLoader(train_dl, device)
# Discriminator: CNN mapping a 3x64x64 image to a single real/fake
# probability (Sigmoid output), trained with binary cross-entropy.
discriminator = nn.Sequential(
# in: 3 x 64 x 64
# NOTE(review): BatchNorm + Dropout on the first layer deviates from the
# usual DCGAN recipe (which skips BatchNorm on the discriminator input).
nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.Dropout(),
nn.LeakyReLU(0.2, inplace=True),
# out: 64 x 32 x 32
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# out: 128 x 16 x 16
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# out: 256 x 8 x 8
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# out: 512 x 4 x 4
# added 1. Conv-BN blocks
nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.2, inplace=True),
# out: 1024 x 2 x 2
# added 2. Pooling (kernel 2, stride 2, padding 1 keeps the 2x2 map)
nn.AvgPool2d(kernel_size=(2,2),stride=2,padding=1),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.2, inplace=True),
# Final conv collapses the 2x2 map to a single logit per image.
nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=1, bias=False),
# out: 1 x 1 x 1
nn.Flatten(),
nn.Sigmoid())
# Move the discriminator onto the training device.
discriminator = to_device(discriminator, device)
#latent_size = 128
# Increase latent_size
# Dimensionality of the generator's input noise vector.
latent_size=256
# Generator: transposed-conv decoder mapping a latent_size x 1 x 1 noise
# vector to a 3 x 64 x 64 image in [-1, 1] (Tanh output, matching `stats`).
generator = nn.Sequential(
# in: latent_size x 1 x 1
nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
# NOTE(review): Dropout and LeakyReLU in the generator deviate from the
# standard DCGAN recipe (which uses plain ReLU and no Dropout here).
nn.Dropout(),
nn.LeakyReLU(0.2, inplace=True),
# out: 512 x 4 x 4
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# out: 256 x 8 x 8
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# out: 128 x 16 x 16
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
# Try use ELU or LeakyReLU
# out: 64 x 32 x 32
# nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(64),
# nn.ReLU(True),
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
# out: 3 x 64 x 64
)
# Smoke-test the untrained generator (still on CPU here): a batch of random
# latent vectors should produce a batch of 3x64x64 images.
xb = torch.randn(batch_size, latent_size, 1, 1) # random latent tensors
fake_images = generator(xb)
print(fake_images.shape)
show_images(fake_images)
torch.Size([128, 3, 64, 64])
# Move the generator to the training device and print both architectures.
generator = to_device(generator, device)
print(discriminator,generator)
Sequential( (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): Dropout(p=0.5, inplace=False) (3): LeakyReLU(negative_slope=0.2, inplace=True) (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): LeakyReLU(negative_slope=0.2, inplace=True) (7): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (9): LeakyReLU(negative_slope=0.2, inplace=True) (10): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (11): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (12): LeakyReLU(negative_slope=0.2, inplace=True) (13): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (14): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (15): LeakyReLU(negative_slope=0.2, inplace=True) (16): AvgPool2d(kernel_size=(2, 2), stride=2, padding=1) (17): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (18): LeakyReLU(negative_slope=0.2, inplace=True) (19): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False) (20): Flatten(start_dim=1, end_dim=-1) (21): Sigmoid() ) Sequential( (0): ConvTranspose2d(256, 512, kernel_size=(4, 4), stride=(1, 1), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): Dropout(p=0.5, inplace=False) (3): LeakyReLU(negative_slope=0.2, inplace=True) (4): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): LeakyReLU(negative_slope=0.2, inplace=True) (7): 
ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (9): LeakyReLU(negative_slope=0.2, inplace=True) (10): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (12): LeakyReLU(negative_slope=0.2, inplace=True) (13): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (14): Tanh() )
def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a batch of real images.

    Computes BCE loss on the real batch (target 1) plus BCE loss on a freshly
    generated fake batch (target 0), then steps the discriminator optimizer.

    Args:
        real_images: batch of real images already on `device`.
        opt_d: optimizer over `discriminator.parameters()`.

    Returns:
        (loss, real_score, fake_score): total loss value and the mean
        discriminator outputs on real and fake images.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()
    # Pass real images through discriminator
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()
    # Generate fake images. Detach them so backward() below does not build
    # (and waste time/memory on) gradients through the generator during the
    # discriminator step; the generator is updated separately.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent).detach()
    # Pass fake images through discriminator
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()
    # Update discriminator weights on the combined loss
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score
def train_generator(opt_g):
    """One generator update: generate fakes and push D's output toward 1."""
    # Reset gradients left over from the previous step
    opt_g.zero_grad()
    # Sample a batch of latent vectors and decode them into images
    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    generated = generator(noise)
    # Generator loss: the discriminator should label these as real (1)
    predictions = discriminator(generated)
    real_labels = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(predictions, real_labels)
    # Backpropagate and update the generator weights
    g_loss.backward()
    opt_g.step()
    return g_loss.item()
# Directory where sampled generator outputs are written each epoch.
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)
def save_samples(index, latent_tensors, show=True):
    """Generate images from `latent_tensors`, save them as a PNG grid, and
    optionally display the grid inline.

    Args:
        index: integer used to number the output file (zero-padded to 4).
        latent_tensors: latent batch on the same device as the generator.
        show: when True, also render the grid with matplotlib.
    """
    # Inference only: no_grad avoids building an autograd graph for sampling.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(denorm(fake_images).cpu().detach(), nrow=8).permute(1, 2, 0))
        plt.show()
# Fixed latent batch reused every epoch so saved samples are comparable.
fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)
save_samples(0, fixed_latent)
Saving generated-images-0000.png
def fit(epochs, lr, start_idx=1):
    """Adversarial training loop.

    For each epoch, alternately trains the discriminator and the generator on
    every batch, then logs the last batch's metrics and saves a sample grid.

    Args:
        epochs: number of passes over `train_dl`.
        lr: learning rate for both Adam optimizers.
        start_idx: offset for the saved-sample file numbering.

    Returns:
        (losses_g, losses_d, real_scores, fake_scores) — one entry per epoch,
        taken from that epoch's final batch.
    """
    torch.cuda.empty_cache()
    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []
    # Create optimizers (betas=(0.5, 0.999) per the DCGAN recipe)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)
        # Record losses & scores (values from the epoch's last batch)
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)
        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))
        # Save generated images from the fixed latent batch
        save_samples(epoch+start_idx, fixed_latent, show=True)
    return losses_g, losses_d, real_scores, fake_scores
# Hyperparameters: DCGAN's recommended learning rate; 60 training epochs.
lr = 0.0002
epochs = 60
history = fit(epochs, lr)
losses_g, losses_d, real_scores, fake_scores = history
# Display the sample grid saved after the final epoch.
from IPython.display import Image
Image('./generated/generated-images-0060.png')
# Stitch the per-epoch sample grids into a 1-fps training-progress video.
vid_fname = 'gans_training.mp4'
files = [os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'generated' in f]
# Zero-padded filenames sort lexicographically into epoch order.
files.sort()
# Frame size 530x530 matches make_grid's 8x8 grid of 64px images with the
# default 2px padding (8*64 + 9*2) — presumably; confirm against the PNGs.
out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MP4V'), 1, (530, 530))
# Plain loop instead of a throwaway list comprehension: out.write is called
# purely for its side effect.
for fname in files:
    out.write(cv2.imread(fname))
out.release()
# Plot per-epoch discriminator and generator losses side by side.
plt.plot(losses_d, '-')
plt.plot(losses_g, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['Discriminator', 'Generator'])
plt.title('Losses');
# Plot the discriminator's mean per-epoch scores on real vs. fake batches
# (healthy adversarial training keeps these away from 0 and 1).
plt.plot(real_scores, '-')
plt.plot(fake_scores, '-')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend(['Real', 'Fake'])
plt.title('Scores');